{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Patch Time Series Transformer in HuggingFace - Getting Started\n",
    "\n",
    "In this blog, we provide examples of how to get started with PatchTST. We first demonstrate the forecasting capability of `PatchTST` on the Electricity data. We will then demonstrate the transfer learning capability of `PatchTST` by using the previously trained model to do zero-shot forecasting on the electrical transformer (ETTh1) dataset. The zero-shot forecasting performance will denote the `test` performance of the model in the `target` domain, without any  training on the target domain. Subsequently, we will do linear probing and (then) finetuning of the pretrained model on the `train` part of the target data and will validate the forecasting performance on the `test` part of the target data.\n",
    "\n",
    "The `PatchTST` model was proposed in A Time Series is Worth [64 Words: Long-term Forecasting with Transformers](https://huggingface.co/papers/2211.14730) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam and presented at ICLR 2023.\n",
    "\n",
    "\n",
    "## Quick overview of PatchTST\n",
    "\n",
    "At a high level, the model vectorizes individual time series in a batch into patches of a given size and encodes the resulting sequence of vectors via a Transformer that then outputs the prediction length forecast via an appropriate head.\n",
    "\n",
    "The model is based on two key components: \n",
    "  1. segmentation of time series into subseries-level patches which serve as input tokens to the Transformer; \n",
    "  2.  channel-independence where each channel contains a single univariate time series that shares the same embedding and Transformer weights across all the series, i.e. a [global](https://doi.org/10.1016/j.ijforecast.2021.03.004) univariate model. \n",
    "\n",
    "The patching design naturally has three-fold benefit: \n",
    " - local semantic information is retained in the embedding; \n",
    " - computation and memory usage of the attention maps are quadratically reduced given the same look-back window via strides between patches; and \n",
    " - the model can attend longer history via a trade-off between the patch length (input vector size) and the context length (number of sequences).\n",
    " \n",
    "\n",
    "In addition, `PatchTST` has a modular design to seamlessly support masked time series pre-training as well as direct time series forecasting.\n",
    "\n",
    "| ![PatchTST model schematics](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/patchtst/patchtst-arch.png) |\n",
    "|:--:|\n",
    "|(a) `PatchTST` model overview where a batch of \\\\(M\\\\) time series each of length \\\\(L\\\\) are processed independently (by reshaping them into the batch dimension) via a Transformer backbone and then reshaping the resulting batch back into \\\\(M \\\\) series of prediction length \\\\(T\\\\). Each *univariate* series can be processed in a supervised fashion (b) where the patched set of vectors is used to output the full prediction length or in a self-supervised fashion (c) where masked patches are predicted. |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Installation\n",
    "\n",
    "This demo requires Hugging Face [`Transformers`](https://github.com/huggingface/transformers) for the model, and the IBM `tsfm` package for auxiliary data pre-processing.\n",
    "We can install both by cloning the `tsfm` repository and following the below steps.\n",
    "\n",
    "1. Clone the public IBM Time Series Foundation Model Repository [`tsfm`](https://github.com/ibm/tsfm).\n",
    "    ```bash\n",
    "    pip install git+https://github.com/IBM/tsfm.git\n",
    "    ```\n",
    "2. Install Hugging Face [`Transformers`](https://github.com/huggingface/transformers#installation)\n",
    "    ```bash\n",
    "    pip install transformers\n",
    "    ```\n",
    "3. Test it with the following commands in a `python` terminal.\n",
    "    ```python\n",
    "    from transformers import PatchTSTConfig\n",
    "    from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
    "    ```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 1: Forecasting on the Electricity dataset\n",
    "Here we train a `PatchTST` model directly on the Electricity dataset (available from https://github.com/zhouhaoyi/Informer2020), and evaluate its performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard\n",
    "import os\n",
    "\n",
    "# Third Party\n",
    "from transformers import (\n",
    "    EarlyStoppingCallback,\n",
    "    PatchTSTConfig,\n",
    "    PatchTSTForPrediction,\n",
    "    set_seed,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# First Party\n",
    "from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
    "from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor\n",
    "from tsfm_public.toolkit.util import select_by_index\n",
    "\n",
    "# supress some warnings\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\", module=\"torch\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ### Set seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(2023)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load and prepare datasets\n",
    "\n",
    " In the next cell, please adjust the following parameters to suit your application:\n",
    " - `dataset_path`: path to local .csv file, or web address to a csv file for the data of interest. Data is loaded with pandas, so anything supported by\n",
    "   `pd.read_csv` is supported: (https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html).\n",
    " - `timestamp_column`: column name containing timestamp information, use `None` if there is no such column.\n",
    " - `id_columns`: List of column names specifying the IDs of different time series. If no ID column exists, use `[]`.\n",
    " - `forecast_columns`: List of columns to be modeled\n",
    " - `context_length`: The amount of historical data used as input to the model. Windows of the input time series data with length equal to `context_length` will be extracted from the input dataframe. In the case of a multi-time series dataset, the context windows will be created so that they are contained within a single time series (i.e., a single ID).\n",
    " - `forecast_horizon`: Number of timestamps to forecast in the future.\n",
    " - `train_start_index`, `train_end_index`: the start and end indices in the loaded data which delineate the training data.\n",
    " - `valid_start_index`, `eval_end_index`: the start and end indices in the loaded data which delineate the validation data.\n",
    " - `test_start_index`, `eval_end_index`: the start and end indices in the loaded data which delineate the test data.\n",
    " - `patch_length`: The patch length for the `PatchTST` model. It is recommended to choose a value that evenly divides `context_length`.\n",
    " - `num_workers`: Number of CPU workers in the PyTorch dataloader.\n",
    " - `batch_size`: Batch size.\n",
    "\n",
    "The data is first loaded into a Pandas dataframe and split into training, validation, and test parts. Then the Pandas dataframes are converted to the appropriate PyTorch dataset required for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The ECL data is available from https://github.com/zhouhaoyi/Informer2020?tab=readme-ov-file#data\n",
    "dataset_path = \"~/data/ECL.csv\"\n",
    "timestamp_column = \"date\"\n",
    "id_columns = []\n",
    "\n",
    "context_length = 512\n",
    "forecast_horizon = 96\n",
    "patch_length = 16\n",
    "num_workers = 16  # Reduce this if you have low number of CPU cores\n",
    "batch_size = 64  # Adjust according to GPU memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(\n",
    "    dataset_path,\n",
    "    parse_dates=[timestamp_column],\n",
    ")\n",
    "forecast_columns = list(data.columns[1:])\n",
    "\n",
    "# get split\n",
    "num_train = int(len(data) * 0.7)\n",
    "num_test = int(len(data) * 0.2)\n",
    "num_valid = len(data) - num_train - num_test\n",
    "border1s = [\n",
    "    0,\n",
    "    num_train - context_length,\n",
    "    len(data) - num_test - context_length,\n",
    "]\n",
    "border2s = [num_train, num_train + num_valid, len(data)]\n",
    "\n",
    "train_start_index = border1s[0]  # None indicates beginning of dataset\n",
    "train_end_index = border2s[0]\n",
    "\n",
    "# we shift the start of the evaluation period back by context length so that\n",
    "# the first evaluation timestamp is immediately following the training data\n",
    "valid_start_index = border1s[1]\n",
    "valid_end_index = border2s[1]\n",
    "\n",
    "test_start_index = border1s[2]\n",
    "test_end_index = border2s[2]\n",
    "\n",
    "train_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=train_start_index,\n",
    "    end_index=train_end_index,\n",
    ")\n",
    "valid_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=valid_start_index,\n",
    "    end_index=valid_end_index,\n",
    ")\n",
    "test_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=test_start_index,\n",
    "    end_index=test_end_index,\n",
    ")\n",
    "\n",
    "time_series_preprocessor = TimeSeriesPreprocessor(\n",
    "    timestamp_column=timestamp_column,\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    scaling=True,\n",
    ")\n",
    "time_series_preprocessor = time_series_preprocessor.train(train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = ForecastDFDataset(\n",
    "    time_series_preprocessor.preprocess(train_data),\n",
    "    id_columns=id_columns,\n",
    "    timestamp_column=\"date\",\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")\n",
    "valid_dataset = ForecastDFDataset(\n",
    "    time_series_preprocessor.preprocess(valid_data),\n",
    "    id_columns=id_columns,\n",
    "    timestamp_column=\"date\",\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")\n",
    "test_dataset = ForecastDFDataset(\n",
    "    time_series_preprocessor.preprocess(test_data),\n",
    "    id_columns=id_columns,\n",
    "    timestamp_column=\"date\",\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configure the PatchTST model\n",
    "\n",
    "Next, we instantiate a randomly initialized `PatchTST` model with a configuration. The settings below control the different hyperparameters related to the architecture.\n",
    "  - `num_input_channels`: the number of input channels (or dimensions) in the time series data. This is\n",
    "    automatically set to the number for forecast columns.\n",
    "  - `context_length`: As described above, the amount of historical data used as input to the model.\n",
    "  - `patch_length`: The length of the patches extracted from the context window (of length `context_length`).\n",
    "  - `patch_stride`: The stride used when extracting patches from the context window.\n",
    "  - `random_mask_ratio`: The fraction of input patches that are completely masked for pretraining the model.\n",
    "  - `d_model`: Dimension of the transformer layers.\n",
    "  - `num_attention_heads`: The number of attention heads for each attention layer in the Transformer encoder.\n",
    "  - `num_hidden_layers`: The number of encoder layers.\n",
    "  - `ffn_dim`: Dimension of the intermediate (often referred to as feed-forward) layer in the encoder.\n",
    "  - `dropout`: Dropout probability for all fully connected layers in the encoder.\n",
    "  - `head_dropout`: Dropout probability used in the head of the model.\n",
    "  - `pooling_type`: Pooling of the embedding. `\"mean\"`, `\"max\"` and `None` are supported.\n",
    "  - `channel_attention`: Activate the channel attention block in the Transformer to allow channels to attend to each other.\n",
    "  - `scaling`: Whether to scale the input targets via \"mean\" scaler, \"std\" scaler, or no scaler if `None`. If `True`, the\n",
    "    scaler is set to `\"mean\"`.\n",
    "  - `loss`: The loss function for the model corresponding to the `distribution_output` head. For parametric\n",
    "    distributions it is the negative log-likelihood (`\"nll\"`) and for point estimates it is the mean squared\n",
    "    error `\"mse\"`.\n",
    "  - `pre_norm`: Normalization is applied before self-attention if pre_norm is set to `True`. Otherwise, normalization is\n",
    "    applied after residual block.\n",
    "  - `norm_type`: Normalization at each Transformer layer. Can be `\"BatchNorm\"` or `\"LayerNorm\"`.\n",
    "\n",
    "For full details on the parameters, we refer to the [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/patchtst#transformers.PatchTSTConfig).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = PatchTSTConfig(\n",
    "    num_input_channels=len(forecast_columns),\n",
    "    context_length=context_length,\n",
    "    patch_length=patch_length,\n",
    "    patch_stride=patch_length,\n",
    "    prediction_length=forecast_horizon,\n",
    "    random_mask_ratio=0.4,\n",
    "    d_model=128,\n",
    "    num_attention_heads=16,\n",
    "    num_hidden_layers=3,\n",
    "    ffn_dim=256,\n",
    "    dropout=0.2,\n",
    "    head_dropout=0.2,\n",
    "    pooling_type=None,\n",
    "    channel_attention=False,\n",
    "    scaling=\"std\",\n",
    "    loss=\"mse\",\n",
    "    pre_norm=True,\n",
    "    norm_type=\"batchnorm\",\n",
    ")\n",
    "model = PatchTSTForPrediction(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train model\n",
    "\n",
    "Next, we can leverage the Hugging Face [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) class to train the model based on the direct forecasting strategy. We first define the [TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) which lists various hyperparameters for training such as the number of epochs, learning rate and so on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./checkpoint/patchtst/electricity/pretrain/output/\",\n",
    "    overwrite_output_dir=True,\n",
    "    # learning_rate=0.001,\n",
    "    num_train_epochs=100,\n",
    "    do_eval=True,\n",
    "    eval_strategy=\"epoch\",\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    dataloader_num_workers=num_workers,\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    save_total_limit=3,\n",
    "    logging_dir=\"./checkpoint/patchtst/electricity/pretrain/logs/\",  # Make sure to specify a logging directory\n",
    "    load_best_model_at_end=True,  # Load the best model when training ends\n",
    "    metric_for_best_model=\"eval_loss\",  # Metric to monitor for early stopping\n",
    "    greater_is_better=False,  # For loss\n",
    "    label_names=[\"future_values\"],\n",
    ")\n",
    "\n",
    "# Create the early stopping callback\n",
    "early_stopping_callback = EarlyStoppingCallback(\n",
    "    early_stopping_patience=10,  # Number of epochs with no improvement after which to stop\n",
    "    early_stopping_threshold=0.0001,  # Minimum improvement required to consider as improvement\n",
    ")\n",
    "\n",
    "# define trainer\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=valid_dataset,\n",
    "    callbacks=[early_stopping_callback],\n",
    "    # compute_metrics=compute_metrics,\n",
    ")\n",
    "\n",
    "# pretrain\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate the model on the test set of the source domain\n",
    "\n",
    "Next, we can leverage `trainer.evaluate()` to calculate test metrics. While this is not the target metric to judge in this task, it provides a reasonable check that the pretrained model has trained properly.\n",
    "Note that the training and evaluation loss for `PatchTST` is the Mean Squared Error (MSE) loss. Hence, we do not separately compute the MSE metric in any of the following evaluation experiments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = trainer.evaluate(test_dataset)\n",
    "print(\"Test result:\")\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The MSE of `0.131` is very close to the value reported for the Electricity dataset in the original `PatchTST` paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = \"patchtst/electricity/model/pretrain/\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "trainer.save_model(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 2: Transfer Learning from Electricity to ETTh1\n",
    "\n",
    "\n",
    "In this section, we will demonstrate the transfer learning capability of the `PatchTST` model.\n",
    "We use the model pre-trained on the Electricity dataset to do zero-shot forecasting on the ETTh1 dataset.\n",
    "\n",
    "\n",
    "By Transfer Learning, we mean that we first pretrain the model for a forecasting task on a `source` dataset (which we did above on the `Electricity` dataset). Then, we will use the pretrained model for zero-shot forecasting on a `target` dataset. By zero-shot, we mean that we test the performance in the `target` domain without any additional training. We hope that the model gained enough knowledge from pretraining which can be transferred to a different dataset. \n",
    "Subsequently, we will do linear probing and (then) finetuning of the pretrained model on the `train` split of the target data and will validate the forecasting performance on the `test` split of the target data. In this example, the source dataset is the `Electricity` dataset and the target dataset is ETTh1.\n",
    "\n",
    "### Transfer learning on ETTh1 data. \n",
    "All evaluations are on the `test` part of the `ETTh1` data.\n",
    "\n",
    "Step 1: Directly evaluate the electricity-pretrained model. This is the zero-shot performance. \n",
    "\n",
    "Step 2: Evaluate after doing linear probing. \n",
    "\n",
    "Step 3: Evaluate after doing full finetuning. \n",
    "\n",
    "### Load ETTh dataset\n",
    "\n",
    "Below, we load the `ETTh1` dataset as a Pandas dataframe. Next, we create 3 splits for training, validation, and testing. We then leverage the `TimeSeriesPreprocessor` class to prepare each split for the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"ETTh1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Loading target dataset: {dataset}\")\n",
    "dataset_path = f\"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/{dataset}.csv\"\n",
    "timestamp_column = \"date\"\n",
    "id_columns = []\n",
    "forecast_columns = [\"HUFL\", \"HULL\", \"MUFL\", \"MULL\", \"LUFL\", \"LULL\", \"OT\"]\n",
    "train_start_index = None  # None indicates beginning of dataset\n",
    "train_end_index = 12 * 30 * 24\n",
    "\n",
    "# we shift the start of the evaluation period back by context length so that\n",
    "# the first evaluation timestamp is immediately following the training data\n",
    "valid_start_index = 12 * 30 * 24 - context_length\n",
    "valid_end_index = 12 * 30 * 24 + 4 * 30 * 24\n",
    "\n",
    "test_start_index = 12 * 30 * 24 + 4 * 30 * 24 - context_length\n",
    "test_end_index = 12 * 30 * 24 + 8 * 30 * 24"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(\n",
    "    dataset_path,\n",
    "    parse_dates=[timestamp_column],\n",
    ")\n",
    "\n",
    "train_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=train_start_index,\n",
    "    end_index=train_end_index,\n",
    ")\n",
    "valid_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=valid_start_index,\n",
    "    end_index=valid_end_index,\n",
    ")\n",
    "test_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=test_start_index,\n",
    "    end_index=test_end_index,\n",
    ")\n",
    "\n",
    "time_series_preprocessor = TimeSeriesPreprocessor(\n",
    "    timestamp_column=timestamp_column,\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    scaling=True,\n",
    ")\n",
    "time_series_preprocessor = time_series_preprocessor.train(train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = ForecastDFDataset(\n",
    "    time_series_preprocessor.preprocess(train_data),\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")\n",
    "valid_dataset = ForecastDFDataset(\n",
    "    time_series_preprocessor.preprocess(valid_data),\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")\n",
    "test_dataset = ForecastDFDataset(\n",
    "    time_series_preprocessor.preprocess(test_data),\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Zero-shot forecasting on ETTH\n",
    "\n",
    "As we are going to test forecasting performance out-of-the-box, we load the model which we pretrained above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_forecast_model = PatchTSTForPrediction.from_pretrained(\n",
    "    \"patchtst/electricity/model/pretrain/\",\n",
    "    num_input_channels=len(forecast_columns),\n",
    "    head_dropout=0.7,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_forecast_args = TrainingArguments(\n",
    "    output_dir=\"./checkpoint/patchtst/transfer/finetune/output/\",\n",
    "    overwrite_output_dir=True,\n",
    "    learning_rate=0.0001,\n",
    "    num_train_epochs=100,\n",
    "    do_eval=True,\n",
    "    eval_strategy=\"epoch\",\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    dataloader_num_workers=num_workers,\n",
    "    report_to=\"tensorboard\",\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    save_total_limit=3,\n",
    "    logging_dir=\"./checkpoint/patchtst/transfer/finetune/logs/\",  # Make sure to specify a logging directory\n",
    "    load_best_model_at_end=True,  # Load the best model when training ends\n",
    "    metric_for_best_model=\"eval_loss\",  # Metric to monitor for early stopping\n",
    "    greater_is_better=False,  # For loss\n",
    "    label_names=[\"future_values\"],\n",
    ")\n",
    "\n",
    "# Create a new early stopping callback with faster convergence properties\n",
    "early_stopping_callback = EarlyStoppingCallback(\n",
    "    early_stopping_patience=10,  # Number of epochs with no improvement after which to stop\n",
    "    early_stopping_threshold=0.001,  # Minimum improvement required to consider as improvement\n",
    ")\n",
    "\n",
    "finetune_forecast_trainer = Trainer(\n",
    "    model=finetune_forecast_model,\n",
    "    args=finetune_forecast_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=valid_dataset,\n",
    "    callbacks=[early_stopping_callback],\n",
    ")\n",
    "\n",
    "print(\"\\n\\nDoing zero-shot forecasting on target data\")\n",
    "result = finetune_forecast_trainer.evaluate(test_dataset)\n",
    "print(\"Target data zero-shot forecasting result:\")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As can be seen, with a zero-shot forecasting approach we obtain an MSE of 0.370 which is near to the state-of-the-art result in the original `PatchTST` paper.\n",
    "\n",
    "Next, let's see how we can do by performing linear probing, which involves training a linear layer on top of a frozen pre-trained model. Linear probing is often done to test the performance of features of a pretrained model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Linear probing on ETTh1\n",
    "\n",
    "We can do a quick linear probing on the `train` part of the target data to see any possible `test` performance improvement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze the backbone of the model\n",
    "for param in finetune_forecast_trainer.model.model.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "print(\"\\n\\nLinear probing on the target data\")\n",
    "finetune_forecast_trainer.train()\n",
    "print(\"Evaluating\")\n",
    "result = finetune_forecast_trainer.evaluate(test_dataset)\n",
    "print(\"Target data head/linear probing result:\")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As can be seen, by only training a simple linear layer on top of the frozen backbone, the MSE decreased from 0.370 to 0.357, beating the originally reported results!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = f\"patchtst/electricity/model/transfer/{dataset}/model/linear_probe/\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "finetune_forecast_trainer.save_model(save_dir)\n",
    "\n",
    "save_dir = f\"patchtst/electricity/model/transfer/{dataset}/preprocessor/\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "time_series_preprocessor.save_pretrained(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's see if we can get additional improvements by doing a full fine-tune of the model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Full fine-tune on ETTh1\n",
    "\n",
    "We can do a full model fine-tune (instead of probing the last linear layer as shown above) on the `train` part of the target data to see a possible `test` performance improvement. The code looks similar to the linear probing task above, except that we are not freezing any parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the model\n",
    "finetune_forecast_model = PatchTSTForPrediction.from_pretrained(\n",
    "    \"patchtst/electricity/model/pretrain/\",\n",
    "    num_input_channels=len(forecast_columns),\n",
    "    dropout=0.7,\n",
    "    head_dropout=0.7,\n",
    ")\n",
    "finetune_forecast_trainer = Trainer(\n",
    "    model=finetune_forecast_model,\n",
    "    args=finetune_forecast_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=valid_dataset,\n",
    "    callbacks=[early_stopping_callback],\n",
    ")\n",
    "print(\"\\n\\nFinetuning on the target data\")\n",
    "finetune_forecast_trainer.train()\n",
    "print(\"Evaluating\")\n",
    "result = finetune_forecast_trainer.evaluate(test_dataset)\n",
    "print(\"Target data full finetune result:\")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this case, there is only a small improvement on the ETTh1 dataset with full fine-tuning. For other datasets there may be more substantial improvements. Let's save the model anyway."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = f\"patchtst/electricity/model/transfer/{dataset}/model/fine_tuning/\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "finetune_forecast_trainer.save_model(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "In this blog, we presented a step-by-step guide on training `PatchTST` for tasks related to forecasting and transfer learning, demonstrating various approaches for fine-tuning. We intend to facilitate easy integration of the `PatchTST` HF model for your forecasting use cases, and we hope that this content serves as a useful resource to expedite the adoption of PatchTST. Thank you for tuning in to our blog, and we hope you find this information beneficial for your projects."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
